{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker PySpark PCA on Spark and K-Means Clustering on SageMaker MNIST Example\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Setup](#Setup)\n",
    "3. [Loading the Data](#Loading-the-Data)\n",
    "4. [Create a hybrid pipeline with Spark PCA and SageMaker K-Means](#Create-a-hybrid-pipeline-with-Spark-PCA-and-SageMaker-K-Means)\n",
    "5. [Inference](#Inference)\n",
    "6. [Clean-up](#Clean-up)\n",
    "7. [More on SageMaker Spark](#More-on-SageMaker-Spark)\n",
    "\n",
    "## Introduction\n",
    "This notebook will show how to cluster handwritten digits through the SageMaker PySpark library. \n",
    "\n",
    "We will manipulate data through Spark using a SparkSession, and then use the SageMaker Spark library to interact with SageMaker for training and inference. \n",
    "We will create a pipeline consisting of a first step to reduce the dimensionality using Spark MLLib PCA algorithm, followed by the final K-Means clustering step on SageMaker. \n",
    "\n",
    "You can visit SageMaker Spark's GitHub repository at https://github.com/aws/sagemaker-spark to learn more about SageMaker Spark.\n",
    "\n",
    "This notebook was created and tested on an ml.m4.xlarge notebook instance.\n",
    "\n",
    "## Why use Spark MLLib algorithms? \n",
    "\n",
    "The use of Spark MLLib PCA in this notebook is meant to showcase how you can use different pre-processting steps, ranging from data transformers to algorithms, with tools such as Spark MLLib that are well suited for data pre-processing. You can then use SageMaker algorithms and features through the SageMaker-Spark SDK. Here in our case, PCA is in charge of reducing the feature vector as a pre-processing step, and K-Means responsible for clustering the data. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, we import the necessary modules and create the `SparkSession` with the SageMaker-Spark dependencies attached. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import boto3\n",
    "\n",
    "from pyspark import SparkContext, SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "import sagemaker_pyspark\n",
    "\n",
    "role = get_execution_role()\n",
    "\n",
    "# Configure Spark to use the SageMaker Spark dependency jars\n",
    "jars = sagemaker_pyspark.classpath_jars()\n",
    "\n",
    "classpath = \":\".join(sagemaker_pyspark.classpath_jars())\n",
    "\n",
    "# See the SageMaker Spark Github to learn how to connect to EMR from a notebook instance\n",
    "spark = SparkSession.builder.config(\"spark.driver.extraClassPath\", classpath)\\\n",
    "    .master(\"local[*]\").getOrCreate()\n",
    "    \n",
    "spark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Data\n",
    "\n",
    "Now, we load the MNIST dataset into a Spark Dataframe, which dataset is available in LibSVM format at\n",
    "\n",
    "`s3://sagemaker-sample-data-[region]/spark/mnist/`\n",
    "\n",
    "where `[region]` is replaced with a supported AWS region, such as us-east-1.\n",
    "\n",
    "In order to train and make inferences our input DataFrame must have a column of Doubles (named \"label\" by default) and a column of Vectors of Doubles (named \"features\" by default).\n",
    "\n",
    "Spark's LibSVM DataFrameReader loads a DataFrame already suitable for training and inference.\n",
    "\n",
    "Here, we load into a DataFrame in the SparkSession running on the local Notebook Instance, but you can connect your Notebook Instance to a remote Spark cluster for heavier workloads. Starting from EMR 5.11.0, SageMaker Spark is pre-installed on EMR Spark clusters. For more on connecting your SageMaker Notebook Instance to a remote EMR cluster, please see [this blog post](https://aws.amazon.com/blogs/machine-learning/build-amazon-sagemaker-notebooks-backed-by-spark-in-amazon-emr/)."
   ]
  },
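  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the format concrete: each line of a LibSVM file reads `label index1:value1 index2:value2 ...` and lists only the non-zero features, with 1-based indices. Spark's reader turns these into 0-based positions in a vector of length `numFeatures`. The pure-Python sketch below, using a hypothetical `parse_libsvm_line` helper and a made-up sample line, illustrates that mapping into a dense list; it is not how Spark parses the files internally (Spark produces sparse vectors).\n",
    "\n",
    "```python\n",
    "# Hypothetical helper, for illustration only: decode one LibSVM line into a\n",
    "# label and a dense feature list of length num_features.\n",
    "def parse_libsvm_line(line, num_features=784):\n",
    "    tokens = line.split()\n",
    "    label = float(tokens[0])\n",
    "    features = [0.0] * num_features\n",
    "    for token in tokens[1:]:\n",
    "        index, value = token.split(':')\n",
    "        # LibSVM indices are 1-based; shift to a 0-based position\n",
    "        features[int(index) - 1] = float(value)\n",
    "    return label, features\n",
    "\n",
    "# Made-up sample line: label 5 with three non-zero pixels\n",
    "label, features = parse_libsvm_line('5 153:3 154:18 155:18')\n",
    "print(label, len(features), features[152])  # 5.0 784 3.0\n",
    "```"
   ]
  },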
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "spark._jsc.hadoopConfiguration().set('fs.s3a.endpoint', 's3.{}.amazonaws.com'.format(region))\n",
    "\n",
    "trainingData = spark.read.format('libsvm')\\\n",
    "    .option('numFeatures', '784')\\\n",
    "    .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))\n",
    "\n",
    "testData = spark.read.format('libsvm')\\\n",
    "    .option('numFeatures', '784')\\\n",
    "    .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))\n",
    "\n",
    "trainingData.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MNIST images are 28x28, resulting in 784 pixels. The dataset consists of images of digits going from 0 to 9, representing 10 classes. \n",
    "\n",
    "In each row:\n",
    "* The `label` column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.\n",
    "* The `features` column stores a vector (`org.apache.spark.ml.linalg.Vector`) of `Double` values. The length of the vector is 784, as each image consists of 784 pixels. Those pixels are the features we will use. \n",
    "\n",
    "\n",
    "\n",
    "As we are interested in clustering the images of digits, the number of pixels represents the feature vector, while the number of classes represents the number of clusters we want to find. "
   ]
  },
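  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `showDigit` helper later in this notebook turns a feature vector back into an image with NumPy's `reshape((28,28))`, which uses row-major order by default. Under that layout, pixel `(row, col)` sits at position `row * 28 + col` of the 784-element vector. A small sketch, with hypothetical helper names, of that index arithmetic:\n",
    "\n",
    "```python\n",
    "# Hypothetical helpers: convert between a flat pixel index and (row, col),\n",
    "# assuming row-major order, as used by NumPy's default reshape.\n",
    "IMAGE_SIDE = 28\n",
    "\n",
    "def pixel_index(row, col, side=IMAGE_SIDE):\n",
    "    return row * side + col\n",
    "\n",
    "def pixel_coords(index, side=IMAGE_SIDE):\n",
    "    return divmod(index, side)  # (row, col)\n",
    "\n",
    "print(IMAGE_SIDE * IMAGE_SIDE)  # 784\n",
    "print(pixel_index(27, 27))      # 783, the bottom-right pixel\n",
    "print(pixel_coords(152))        # (5, 12)\n",
    "```"
   ]
  },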
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a hybrid pipeline with Spark PCA and SageMaker K-Means\n",
    "To perform the clustering task, we will first running PCA on our feature vector, reducing it to 50 features. Then, we can use K-Means on the result of PCA to apply the final clustering. We will create a **Pipeline** consisting of 2 stages: the PCA stage, and the K-Means stage. \n",
    "\n",
    "In the following example, we run PCA on our Spark cluster, then train and infer using Amazon SageMaker's K-Means on the output column from PCA:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.feature import PCA\n",
    "\n",
    "from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator\n",
    "from sagemaker_pyspark import IAMRole, EndpointCreationPolicy, RandomNamePolicyFactory\n",
    "from sagemaker_pyspark.transformation.serializers import ProtobufRequestRowSerializer\n",
    "\n",
    "# ML pipeline with 2 stages: PCA and K-Means\n",
    "\n",
    "# 1st stage: PCA \n",
    "pcaSparkEstimator = PCA(\n",
    "  inputCol = \"features\",\n",
    "  outputCol = \"projectedFeatures\",\n",
    "  k = 50)\n",
    "\n",
    "# 2nd stage: K-Means on SageMaker\n",
    "kMeansSageMakerEstimator = KMeansSageMakerEstimator(\n",
    "  sagemakerRole = IAMRole(role),\n",
    "  trainingSparkDataFormatOptions = {\"featuresColumnName\": \"projectedFeatures\"}, # use the output column of PCA\n",
    "  requestRowSerializer = ProtobufRequestRowSerializer(featuresColumnName = \"projectedFeatures\"), # use the output column of PCA\n",
    "  trainingInstanceType = \"ml.m4.xlarge\",\n",
    "  trainingInstanceCount = 1,\n",
    "  endpointInstanceType = \"ml.t2.medium\",\n",
    "  endpointInitialInstanceCount = 1,\n",
    "  namePolicyFactory = RandomNamePolicyFactory(\"sparksm-2-\"),\n",
    "  endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_TRANSFORM \n",
    ")\n",
    "\n",
    "# Set parameters for K-Means\n",
    "kMeansSageMakerEstimator.setFeatureDim(50)\n",
    "kMeansSageMakerEstimator.setK(10)\n",
    "\n",
    "# Define the stages of the Pipeline in order\n",
    "pipelineSparkSM = Pipeline(stages=[pcaSparkEstimator, kMeansSageMakerEstimator])"
   ]
  },
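  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what the first stage computes, here is a textbook PCA in NumPy: centre the data, take the top `k` right singular vectors as principal directions, and project onto them. This is an illustrative sketch with a hypothetical `pca_project` function and random stand-in data, not Spark's implementation; in particular, Spark's `PCAModel` may handle centring differently at transform time, so its outputs would not match these numbers.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: textbook PCA via SVD of the mean-centred data\n",
    "def pca_project(X, k):\n",
    "    X = np.asarray(X, dtype=float)\n",
    "    X_centred = X - X.mean(axis=0)\n",
    "    # Rows of Vt are principal directions, sorted by decreasing singular value\n",
    "    _, _, Vt = np.linalg.svd(X_centred, full_matrices=False)\n",
    "    components = Vt[:k].T  # shape (n_features, k)\n",
    "    return X_centred @ components, components\n",
    "\n",
    "# Random stand-in for 200 flattened 28x28 images (not MNIST data)\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.random((200, 784))\n",
    "projected, components = pca_project(X, k=50)\n",
    "print(projected.shape)  # (200, 50)\n",
    "```\n",
    "\n",
    "The `k` argument plays the same role as `k = 50` in the `PCA` stage above, and 50 is also the value passed to `setFeatureDim`, since K-Means only sees the projected vectors."
   ]
  },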
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put this pipeline back into context, let's look at the below architecture that shows what actually runs on the notebook instance (with Spark) and on SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![PCA on Spark and KMeans on SageMaker](img/sagemaker-spark-pca-spark-kmeans-sagemaker-architecture.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've defined the `Pipeline`, we can call fit on the training data. Please note the below code will take several minutes to run and create all the resources needed for this pipeline. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train\n",
    "pipelineModelSparkSM = pipelineSparkSM.fit(trainingData)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we call fit on the pipeline, first the DataFrame will run through the PCA algorithm provided by Spark. The result of the PCA run will be output to the `projectedFeatures` column of the DataFrame. Then, the KMeansSageMakerEstimator takes the resulting DataFrame and runs the training on SageMaker using the provided K-Means algorithm. As we used `EndpointCreationPolicy.CREATE_ON_TRANSFORM`, only the training job will run on `fit`. The model and endpoint will be created once we call `transform`. \n",
    "\n",
    "We've introduced new parameters in the KMeansSageMakerEstimator:\n",
    "* `trainingSparkDataFormatOptions = {\"featuresColumnName\": \"projectedFeatures\"}` configures Spark to serialize the \"projectedFeatures\" column for model training\n",
    "* `requestRowSerializer = ProtobufRequestRowSerializer(featuresColumnName = \"projectedFeatures\")` configures the KMeansModel contained within the PipelineModel returned by fit() to infer on the features in the \"projectedFeatures\" column generated by the PCA step\n"
   ]
  },
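  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The K-Means stage assigns each 50-dimensional projected vector to one of `k = 10` centroids. SageMaker's built-in K-Means uses its own scalable variant of the algorithm, so the following is only a sketch of the classic Lloyd iteration, with hypothetical `lloyd_kmeans` and `assign_clusters` functions and toy 1-D data, to show what a `closest_cluster` assignment means.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def assign_clusters(X, centroids):\n",
    "    # Squared Euclidean distance from every point to every centroid\n",
    "    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)\n",
    "    return dists.argmin(axis=1)\n",
    "\n",
    "def lloyd_kmeans(X, k, n_iter=20, seed=0):\n",
    "    X = np.asarray(X, dtype=float)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Initialise centroids with k distinct data points (a copy, not a view)\n",
    "    centroids = X[rng.choice(len(X), size=k, replace=False)]\n",
    "    for _ in range(n_iter):\n",
    "        closest_cluster = assign_clusters(X, centroids)\n",
    "        for j in range(k):\n",
    "            members = X[closest_cluster == j]\n",
    "            if len(members):  # keep the old centroid if a cluster is empty\n",
    "                centroids[j] = members.mean(axis=0)\n",
    "    return assign_clusters(X, centroids), centroids\n",
    "\n",
    "# Toy data: two well-separated groups of 1-D points, as a column vector\n",
    "rng = np.random.default_rng(1)\n",
    "X = np.concatenate([rng.normal(0.0, 0.1, size=50),\n",
    "                    rng.normal(5.0, 0.1, size=50)]).reshape(-1, 1)\n",
    "labels, centroids = lloyd_kmeans(X, k=2)\n",
    "print(np.bincount(labels))  # [50 50]\n",
    "```\n",
    "\n",
    "Cluster indices are arbitrary: cluster 0 need not contain the digit 0, which fits with the `displayClusters` helper below naming clusters with letters rather than digits."
   ]
  },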
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Let's use our test data on our pipeline by calling `transform`. Please note the below code will take several minutes to run and create the endpoint needed in order to serve this pipeline.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run predictions\n",
    "transformedData = pipelineModelSparkSM.transform(testData)\n",
    "transformedData.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How well did the pipeline perform? Let us display the digits from each of the clusters and manually inspect the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql.types import DoubleType\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import string\n",
    "\n",
    "# Helper function to display a digit\n",
    "def showDigit(img, caption='', xlabel='', subplot=None):\n",
    "    if subplot==None:\n",
    "        _,(subplot)=plt.subplots(1,1)\n",
    "    imgr=img.reshape((28,28))\n",
    "    subplot.axes.get_xaxis().set_ticks([])\n",
    "    subplot.axes.get_yaxis().set_ticks([])\n",
    "    plt.title(caption)\n",
    "    plt.xlabel(xlabel)\n",
    "    subplot.imshow(imgr, cmap='gray')\n",
    "    \n",
    "def displayClusters(data):\n",
    "    images = np.array(data.select(\"features\").cache().take(250))\n",
    "    clusters = data.select(\"closest_cluster\").cache().take(250)\n",
    "\n",
    "    for cluster in range(10):\n",
    "        print('\\n\\n\\nCluster {}:'.format(string.ascii_uppercase[cluster]))\n",
    "        digits = [ img for l, img in zip(clusters, images) if int(l.closest_cluster) == cluster ]\n",
    "        height=((len(digits)-1)//5)+1\n",
    "        width=5\n",
    "        plt.rcParams[\"figure.figsize\"] = (width,height)\n",
    "        _, subplots = plt.subplots(height, width)\n",
    "        subplots=np.ndarray.flatten(subplots)\n",
    "        for subplot, image in zip(subplots, digits):\n",
    "            showDigit(image, subplot=subplot)\n",
    "        for subplot in subplots[len(digits):]:\n",
    "            subplot.axis('off')\n",
    "\n",
    "        plt.show()\n",
    "        \n",
    "displayClusters(transformedData)"
   ]
  },
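  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visual inspection can be complemented by a rough quantitative check. One common measure is cluster purity: for each cluster, count the most frequent true label, then divide the sum of those counts by the total number of points. The sketch below uses a hypothetical `cluster_purity` helper on made-up pairs; the commented line shows one possible way to collect `(label, closest_cluster)` pairs from `transformedData`, which we have not run here and which would need the endpoint to still exist.\n",
    "\n",
    "```python\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "# Hypothetical helper: fraction of points whose true label matches the\n",
    "# majority label of the cluster they were assigned to\n",
    "def cluster_purity(pairs):\n",
    "    by_cluster = defaultdict(Counter)\n",
    "    for label, cluster in pairs:\n",
    "        by_cluster[int(cluster)][int(label)] += 1\n",
    "    total = sum(sum(counts.values()) for counts in by_cluster.values())\n",
    "    majority = sum(counts.most_common(1)[0][1] for counts in by_cluster.values())\n",
    "    return majority / total\n",
    "\n",
    "# One possible way to collect the pairs in this notebook (not run here):\n",
    "# pairs = [(r.label, r.closest_cluster) for r in transformedData.select('label', 'closest_cluster').collect()]\n",
    "\n",
    "# Made-up pairs: cluster 0 holds two 5s and a 3, cluster 1 holds two 7s\n",
    "print(cluster_purity([(5, 0), (5, 0), (3, 0), (7, 1), (7, 1)]))  # 0.8\n",
    "```\n",
    "\n",
    "Purity tends to increase with the number of clusters, so it is mainly useful for comparing runs that use the same number of clusters, here 10."
   ]
  },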
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean-up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we don't need to make any more inferences, now we delete the resources (endpoints, models, configurations, etc):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the resources\n",
    "from sagemaker_pyspark import SageMakerResourceCleanup\n",
    "from sagemaker_pyspark import SageMakerModel\n",
    "\n",
    "def cleanUp(model):\n",
    "    resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)\n",
    "    resource_cleanup.deleteResources(model.getCreatedResources())\n",
    "    \n",
    "# Delete the SageMakerModel in pipeline\n",
    "for m in pipelineModelSparkSM.stages:\n",
    "    if isinstance(m, SageMakerModel):\n",
    "        cleanUp(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More on SageMaker Spark\n",
    "\n",
    "The SageMaker Spark Github repository has more about SageMaker Spark, including how to use SageMaker Spark using the Scala SDK: https://github.com/aws/sagemaker-spark\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
